import torch
from lietorch import SE3
from modules.droid_net import DroidNet
from depth_video import DepthVideo
from motion_filter import MotionFilter
from track_frontend import TrackFrontend
from track_backend import TrackBackend
from util.trajectory_filler import PoseTrajectoryFiller
from util.utils import load_config

from collections import OrderedDict
from torch.multiprocessing import Process, Queue
from gs_backend import GSBackEnd
from pgo_buffer import PGOBuffer

import numpy as np
import os
from util.poses import to_se3_vec

def parse_extra_params(args, config):
    """Attach IMU-related settings from ``config`` to ``args`` and return ``args``.

    The following attributes are written on ``args`` (mutated in place):

    * ``init_g``, ``init_bg``, ``init_ba`` -- NumPy copies of the same-named
      entries of ``config['IMU']`` (presumably initial gravity and gyro /
      accelerometer biases -- confirm against the IMU module).
    * ``Tcb_np`` -- extrinsic matrix; read from ``<imagedir>/../extrinsics.txt``
      when ``args.imagedir`` contains ``livo2`` (case-insensitive), otherwise
      taken from ``config['IMU']['Tcb_np']``.
    * ``Tcb`` -- ``Tcb_np`` converted through ``to_se3_vec`` into an ``SE3``
      object on the CUDA device, with two leading singleton dimensions.

    When ``args.imus`` is not ``None`` its first column is modified in place:
    divided by 1e9 if ``imu_in_nanoseconds`` is set, then shifted by
    ``imu_time_offset``.

    Args:
        args: Run arguments; must provide ``imagedir`` and ``imus``.
        config: Parsed configuration containing an ``'IMU'`` section.

    Returns:
        The same ``args`` object, after mutation.
    """
    imu_cfg = config['IMU']

    for key in ('init_g', 'init_bg', 'init_ba'):
        setattr(args, key, np.array(imu_cfg[key]))

    # 'livo2' sequences ship their extrinsics next to the image directory;
    # everything else takes them from the config file.
    uses_livo2_layout = 'livo2' in args.imagedir.lower()
    if not uses_livo2_layout:
        args.Tcb_np = np.array(imu_cfg['Tcb_np'])
    else:
        extrinsics_file = os.path.join(args.imagedir, '..', 'extrinsics.txt')
        args.Tcb_np = np.loadtxt(extrinsics_file)

    imu_samples = args.imus
    if imu_samples is not None:
        # Column 0 is treated as the timestamp column: convert to seconds if
        # needed, then apply the configured offset (both in place).
        if imu_cfg['imu_in_nanoseconds']:
            imu_samples[:, 0] /= 1e9
        imu_samples[:, 0] += imu_cfg['imu_time_offset']

    se3_vec = to_se3_vec(args.Tcb_np)
    tcb_tensor = torch.tensor(se3_vec, dtype=torch.float, device='cuda')
    args.Tcb = SE3(tcb_tensor[None, None])
    return args

class VIGS:
    """Entry point of the pipeline: tracking, optional PGBA and optional GS mapping.

    The constructor wires a shared ``DepthVideo`` buffer to a motion filter,
    a tracking frontend and backend, a ``GSBackEnd`` and a trajectory filler.
    ``track`` is called once per incoming frame; ``terminate`` is called once
    at the end and returns the full trajectory.
    """

    def __init__(self, args):
        """Load the network and build every pipeline component.

        Args:
            args: Run arguments. This constructor reads ``weights``,
                ``config``, ``gsmapping``, ``image_size``, ``buffer``,
                ``output``, ``gsvis`` and ``rerunvis``;
                ``parse_extra_params`` reads and attaches further fields.
        """
        super().__init__()
        self.load_weights(args.weights)
        self.config = config = load_config(args.config)
        args = parse_extra_params(args, config)
        self.args = args
        self.gsmapping = args.gsmapping
        # frame index -> image; consumed by the trajectory filler and the
        # rendering evaluation in terminate()
        self.images = {}

        # store images, depth, poses, intrinsics (shared between processes)
        self.video = DepthVideo(config, args, args.image_size, args.buffer)

        tracking_cfg = config["Tracking"]

        # filter incoming frames so that there is enough motion
        self.filterx = MotionFilter(self.net, self.video, config, tracking_cfg["disable_mono"])

        # frontend process
        self.frontend = TrackFrontend(self.net, self.video, tracking_cfg["frontend"], args)

        # backend process
        self.backend = TrackBackend(self.net, self.video, tracking_cfg["backend"])

        # 3dgs; the backend object is always built, but it is only
        # cross-linked with the video buffer when mapping is enabled
        self.gs = GSBackEnd(config, self.args.output, args, args.gsvis)
        if self.gsmapping:
            self.video.gs = self.gs
            self.gs.video = self.video

        # post processor - fill in poses for non-keyframes
        self.traj_filler = PoseTrajectoryFiller(self.net, self.video)

        # visualizer (runs in its own process)
        if args.rerunvis:
            from util.droid_visualization_rerun import droid_visualization_rerun
            self.visualizer = Process(
                target=droid_visualization_rerun,
                args=(self.video,),
                kwargs=dict(
                    web_port=9876,                           # port the node will serve on
                    record_path=f"{self.args.output}/rerun_stream.rrd"  # optional
                )
            )
            self.visualizer.start()

        # global PGBA backend
        self.pgba = tracking_cfg["pgba"]["active"]
        if self.pgba:
            # PGBA runs in a thread rather than a forked process so that it
            # stays in the same process (and CUDA context) as the tracker.
            import queue
            import threading

            self.video.pgobuf = PGOBuffer(
                self.net, self.video, self.frontend,
                tracking_cfg["pgba"], args
            )

            # Thread-safe queue for communication within this process.
            self.LC_data_queue = queue.Queue()
            self.video.pgobuf.set_LC_data_queue(self.LC_data_queue)

            # daemon=True: the worker never blocks interpreter shutdown.
            self.mp_backend = threading.Thread(
                target=self.video.pgobuf.spin,
                daemon=True,
            )
            self.mp_backend.start()

    def load_weights(self, weights):
        """Load trained model weights into a fresh ``DroidNet`` on ``cuda:0``.

        Args:
            weights: Path (or file-like object) accepted by ``torch.load``
                holding a state dict, optionally with ``module.`` prefixes.
        """
        self.net = DroidNet()
        # Deserialize onto the CPU so the checkpoint loads regardless of the
        # device it was saved from; the model is moved to the GPU below.
        checkpoint = torch.load(weights, map_location="cpu")
        state_dict = OrderedDict(
            (name.replace("module.", ""), tensor) for name, tensor in checkpoint.items()
        )
        # Keep only the first two output rows of these final layers
        # (presumably the checkpoint heads are wider than this DroidNet's --
        # confirm against modules.droid_net).
        for key in (
            "update.weight.2.weight",
            "update.weight.2.bias",
            "update.delta.2.weight",
            "update.delta.2.bias",
        ):
            state_dict[key] = state_dict[key][:2]
        self.net.load_state_dict(state_dict)
        self.net.to("cuda:0").eval()

    def call_gs(self, viz_idx, dposes=None, dscale=None, final=False, update_idx=None):
        """Package the selected frames and hand them to the GS backend.

        Does nothing when mapping is disabled. For ``livo2`` image directories
        the call is also skipped while no index in ``viz_idx`` exceeds 8.

        Args:
            viz_idx: Index tensor selecting entries of the video buffers.
            dposes: Optional pose updates forwarded as ``pose_updates``.
            dscale: Optional scale updates forwarded as ``scale_updates``.
            final: Forwarded as ``data['final']``.
            update_idx: Optional indices forwarded as ``update_idx``.
        """
        if not self.gsmapping:
            return

        # Check the livo2 warm-up condition first so that skipped calls do
        # not pay for indexing and device-to-host copies of the payload.
        if 'livo2' in self.args.imagedir.lower() and viz_idx.max() <= 8:
            return

        cpu_idx = viz_idx.cpu()
        # images / normals / disps_up are indexed with the CPU index and the
        # other buffers with the original one -- presumably matching where
        # each buffer is stored; confirm against DepthVideo.
        data = {'viz_idx':  cpu_idx,
                'tstamp':   self.video.tstamp[viz_idx].to(device='cpu'),
                'poses':    self.video.poses[viz_idx].to(device='cpu'),
                'images':   self.video.images[cpu_idx],
                'normals':  self.video.normals[cpu_idx],
                # depth is the reciprocal of the upsampled disparity
                'depths':   1. / self.video.disps_up[cpu_idx].to(device='cpu'),
                # NOTE(review): the factor 8 presumably rescales intrinsics
                # from 1/8 resolution to full resolution -- confirm.
                'intrinsics':   self.video.intrinsics[viz_idx].to(device='cpu') * 8,
                'pose_updates':  dposes.to(device='cpu') if dposes is not None else None,
                'scale_updates': dscale.to(device='cpu') if dscale is not None else None,
                'update_idx': update_idx.to(device='cpu') if update_idx is not None else None}
        data['final'] = final

        self.gs.process_track_data(data)

    def track(self, t, tstamp, image, intrinsics=None, is_last=False):
        """ main thread - update map

        Args:
            t: Frame index (frame ID); also the key under which ``image`` is
                stored in ``self.images``.
            tstamp: Actual time of the frame in seconds.
            image: The frame, passed to the motion filter.
            intrinsics: Optional intrinsics passed to the motion filter.
            is_last: Whether this is the final frame of the sequence.
        """
        with torch.no_grad():
            self.images[t] = image

            # check there is enough motion
            self.filterx.track(t, tstamp, image, intrinsics, is_last)

            # local bundle adjustment
            viz_idx = self.frontend(is_last=is_last)

        # Nothing for the mapper when the frontend returned no indices.
        if not len(viz_idx):
            return

        if self.pgba:
            dposes, dscale, lcii, lcjj, _local_ii, _local_jj = self.video.pgobuf.run_pgba(self.LC_data_queue)
            if dposes is not None:
                update_idx = torch.unique(torch.cat([lcii, lcjj]))
                # Push the PGBA corrections for every frame except the newest.
                all_but_last = torch.arange(0, self.video.counter.value - 1, device='cuda')
                self.call_gs(all_but_last, dposes[:-1], dscale[:-1], update_idx=update_idx)

        self.call_gs(viz_idx)

    def terminate(self, inertial=False):
        """Run the final optimization passes and return the full trajectory.

        Drops the frontend and motion filter, runs the tracking backend twice,
        optionally performs the final GS refinement (writing the refined poses
        back into the video buffer), fills in poses for all stored frames and,
        when mapping is enabled, evaluates rendering. The visualizer process
        and the PGBA thread are not stopped here.

        Args:
            inertial: Forwarded to both backend passes.

        Returns:
            NumPy array holding the inverse of the filled trajectory
            (per the original documentation, poses as [t, q]).
        """
        del self.frontend
        del self.filterx

        poses_pre = self.video.poses[:self.video.counter.value].clone()
        # Two backend passes (presumably 7 and 12 iterations -- confirm
        # against TrackBackend).
        self.backend(7, inertial=inertial)
        self.backend(12, inertial=inertial)
        del self.backend
        poses_pos = self.video.poses[:self.video.counter.value].clone()
        # Relative transform taking each pre-backend pose to its refined pose.
        dposes = SE3(poses_pos) * SE3(poses_pre).inv()
        dscale = torch.ones(self.video.counter.value, 1)
        torch.cuda.empty_cache()

        # Final Color Refinement
        if self.gsmapping:
            # TODO: this step optimizes poses with a re-rendering loss; an
            # ablation study is needed to see whether it is necessary (the
            # current paper results include it).
            self.call_gs(torch.arange(0, self.video.counter.value, device='cuda'), dposes, dscale, final=True)
            updated_poses = self.gs.finalize()
            # The first column of updated_poses is dropped before writing back
            # (presumably a timestamp / index column -- confirm).
            self.video.poses[:self.video.counter.value] = torch.tensor(updated_poses[:, 1:])

        traj_full = self.traj_filler(self.images)
        if self.gsmapping:
            self.gs.eval_rendering(
                self.images,
                self.args.gtdepthdir,
                traj_full.matrix().data,
                self.video.tstamp[:self.video.counter.value].to(device='cpu'),
            )
        return traj_full.inv().data.cpu().numpy()
